##### #############################################
#####                                        ######
#####       Analyse topic model              ######
#####                                        ######
##### #############################################

# NOTE(review): rm(list = ls()) deletes every object in the global environment
# of whoever runs or source()s this script. It does not detach packages or
# reset options, so it is not a true "clean session" -- prefer running the
# script in a fresh R process.
rm(list=ls())

# Package dependencies (the version comments record the versions noted by the
# original author).
library(quanteda) # v.3.0.0
library(stm) # v.1.3.6
library(data.table) # v.1.13.6
library(scales) # v1.1.1
# NOTE(review): no sandwich function is called anywhere in the visible script;
# confirm whether robust standard errors were intended before removing it.
library(sandwich) # v.2.5.1

# Load inputs. The rest of the script expects these files to provide the
# objects `debates` (a data.table of speeches) and `debate_dfm_stm` (the
# document-feature matrix used for the topic models).

load("data/debates.Rdata")
load("working/debate_dfm_stm.Rdata")


# Drop the 1992-1997 parliamentary term and all contributions by the Speaker.
# Rows where either condition is NA are dropped, as data.table treats NA in
# `i` as FALSE.

debates <- debates[parliamentary_term != "1992-1997" & is_speaker == FALSE]

# One row per MP per debate: concatenate all of an MP's speeches within a
# debate section and carry along the MP/debate-level covariates.

text_by_mp_in_debate <- debates[
  ,
  .(
    body = paste(body, collapse = " "),
    gender = unique(gender),
    year = unique(year),
    parliamentary_term = unique(parliamentary_term)
  ),
  by = .(section_id, person_id)
]


# Grid of topic counts (K) for which a fitted STM is loaded in the loop below.
topic_numbers <- seq(from = 10, to = 80, by = 10)

# One accumulator per style measure. Element i of each list receives the
# "time trend vs. gender gap" regression fitted for the i-th entry of
# `topic_numbers`.
affect_list <- posemo_list <- negemo_list <- fact_list <- list()
anecdote_list <- aggression_list <- complexity_list <- repetition_list <- list()

# Position of the current K within `topic_numbers`; incremented at the top of
# every loop iteration.
i <- 0

for(k in topic_numbers){
  i <- i + 1
# Load the STM fitted with k topics; the file is expected to provide
# `stm_object`.
load(sprintf("working/stm_out/stm_out_%s.Rdata", k))

# Extract topics & merge

# Document-topic proportions, one column per topic named topic_1 ... topic_k.
topic_labels <- paste0("topic_", seq_len(ncol(stm_object$theta)))
topic_proportions <- data.frame(stm_object$theta)
names(topic_proportions) <- topic_labels

# Attach the proportions to the metadata of the documents that contain at
# least one feature (rows with a zero row sum are excluded).
non_empty_docs <- rowSums(debate_dfm_stm) != 0
out <- data.frame(docvars(debate_dfm_stm)[non_empty_docs, ], topic_proportions)

# Parliamentary term of each MP-in-debate document.
parlimentary_term <- text_by_mp_in_debate[
  ,
  .(parliamentary_term = unique(parliamentary_term)),
  by = .(person_id, section_id)
]

out <- merge(out, parlimentary_term, by = c("person_id", "section_id"))


### Merge in style types -----

# Reloaded on every iteration because `speech_scores` is overwritten below by
# its merge with the topic proportions.
load("working/speech_scores.Rdata")

speech_scores <- speech_scores[parliamentary_term != "1992-1997" & is_speaker == FALSE]

speech_scores <- merge(out, speech_scores, by = c("person_id", "section_id"))

# Right-hand side: gender interacted with every topic except topic_1, which is
# omitted as the reference topic. `gender.x` is presumably the merge-suffixed
# gender column coming from `out` -- confirm against the merged column names.
gender_by_topic <- paste0("gender.x * ", topic_labels[-1])
model_formula <- paste0(gender_by_topic, collapse = " + ")

# Models

# Regress one standardised style score on the shared right-hand side.
fit_style_model <- function(outcome) {
  lm(paste0(outcome, " ~ ", model_formula), data = speech_scores)
}

affect <- fit_style_model("affect_std")
posemo <- fit_style_model("posemo_std")
negemo <- fit_style_model("negemo_std")
fact <- fit_style_model("fact_std")
anecdote <- fit_style_model("anecdote_std")
aggression <- fit_style_model("aggression_std")
complexity <- fit_style_model("complexity_std")
repetition <- fit_style_model("repetition_std")


# Per-topic gender gap implied by a fitted `style ~ gender * topic` model.
#
# Arguments:
#   model  A fitted lm. Coefficient 2 is taken to be the gender main effect and
#          every coefficient whose name contains ":" to be a gender-by-topic
#          interaction (this is the layout produced by `model_formula`,
#          assuming gender has two levels -- confirm for other inputs).
#   ...    Ignored; accepted so existing calls that pass extra arguments
#          (e.g. `main =`) keep working.
#
# Reads the global `topic_labels`: its first element labels the baseline row
# (the omitted topic), the remaining elements label the interaction rows in
# coefficient order.
#
# Returns a data.frame with one row per topic, sorted by increasing z:
#   topic_labels  topic name
#   est           gender coefficient (baseline) or gender + interaction
#   hi, lo        est +/- 1.96 standard errors
#   z             est / standard error
#   y             rank of the row after sorting (1 = smallest z)
topic_gender_gaps <- function(model = affect, ...){

  model_coefs <- coef(model)
  model_vcov <- vcov(model)

  # Extract interaction terms
  interaction_terms <- grep(":", names(model_coefs), fixed = TRUE)

  # data.frame() would silently recycle a label vector whose length happens to
  # divide the number of rows, so check the alignment explicitly.
  stopifnot(length(topic_labels) == length(interaction_terms) + 1L)

  # Baseline gap: the gender main effect, i.e. the gap for the omitted topic.
  # All standard errors are taken from the same covariance matrix.
  baseline <- model_coefs[2]
  baseline_se <- sqrt(model_vcov[2, 2])
  baseline_hi <- baseline + 1.96 * baseline_se
  baseline_lo <- baseline - 1.96 * baseline_se

  # Calculate marginal effects
  marginal_effects <- baseline + model_coefs[interaction_terms]

  # Standard error of a sum of two coefficients:
  # sqrt(Var(a) + Var(b) + 2 * Cov(a, b)). Vectorised over all interactions,
  # which also yields a zero-length result when there are none.
  marginal_effect_ses <- sqrt(
    model_vcov[2, 2] +
      diag(model_vcov)[interaction_terms] +
      2 * model_vcov[2, interaction_terms]
  )

  # Upper and lower confidence intervals
  marginal_effects_hi <- marginal_effects + 1.96 * marginal_effect_ses
  marginal_effects_lo <- marginal_effects - 1.96 * marginal_effect_ses

  # Combine all into a data frame
  out <- data.frame(topic_labels = topic_labels,
                    est = c(baseline, marginal_effects),
                    hi = c(baseline_hi, marginal_effects_hi),
                    lo = c(baseline_lo, marginal_effects_lo),
                    z = c(baseline / baseline_se, marginal_effects / marginal_effect_ses),
                    row.names = NULL)

  out <- out[order(out$z, decreasing = FALSE), ]
  out$y <- seq_len(nrow(out))

  out
}



# Per-topic gender gaps for every style measure. topic_gender_gaps() ignores
# any extra arguments, so none are passed.
affect_diffs <- topic_gender_gaps(affect)
negemo_diffs <- topic_gender_gaps(negemo)
posemo_diffs <- topic_gender_gaps(posemo)
fact_diffs <- topic_gender_gaps(fact)
anecdote_diffs <- topic_gender_gaps(anecdote)
aggression_diffs <- topic_gender_gaps(aggression)
complexity_diffs <- topic_gender_gaps(complexity)
repetition_diffs <- topic_gender_gaps(repetition)


## Predict gender gap for each topic for each style

# Words attributed to each topic per `yearmon`: topic proportion times the
# number of words, summed within each `yearmon` value.
speech_scores <- data.table(speech_scores)
words_topic_month <- data.frame(
  speech_scores[, lapply(.SD, function(x) sum(x * n_words)),
                by = yearmon,
                .SDcols = topic_labels]
)

# Long format: one row per (yearmon, topic); the topic name is in `variable`.
words_topic_month <- reshape2::melt(words_topic_month, id.vars = c("yearmon"))

# Attach each style's gap z-score to the matching topic rows.
# Every *_diffs table is sorted by its OWN z, so the row positions must come
# from a match against that table's own topic_labels. Matching against a
# different table's labels returns the z-score of the wrong topic.
lookup_topic_gap <- function(diffs) {
  diffs$z[match(words_topic_month$variable, diffs$topic_labels)]
}

words_topic_month$affect_gap <- lookup_topic_gap(affect_diffs)
words_topic_month$negemo_gap <- lookup_topic_gap(negemo_diffs)
words_topic_month$posemo_gap <- lookup_topic_gap(posemo_diffs)
words_topic_month$fact_gap <- lookup_topic_gap(fact_diffs)
words_topic_month$anecdote_gap <- lookup_topic_gap(anecdote_diffs)
words_topic_month$aggression_gap <- lookup_topic_gap(aggression_diffs)
words_topic_month$complexity_gap <- lookup_topic_gap(complexity_diffs)
words_topic_month$repetition_gap <- lookup_topic_gap(repetition_diffs)

# Time trend of each topic: the t-statistic (row 2, column 3 of the
# coefficient table) from regressing the topic's word total on `yearmon`.
z <- vapply(topic_labels, function(label) {
  topic_rows <- words_topic_month[words_topic_month$variable == label, ]
  coef(summary(lm(value ~ yearmon, data = topic_rows)))[2, 3]
}, numeric(1), USE.NAMES = FALSE)


# Add the time-trend statistic to each gap table, aligned by topic label
# (the tables are not in topic order).
affect_diffs$z_time <- z[match(affect_diffs$topic_labels, topic_labels)]
posemo_diffs$z_time <- z[match(posemo_diffs$topic_labels, topic_labels)]
negemo_diffs$z_time <- z[match(negemo_diffs$topic_labels, topic_labels)]
fact_diffs$z_time <- z[match(fact_diffs$topic_labels, topic_labels)]
anecdote_diffs$z_time <- z[match(anecdote_diffs$topic_labels, topic_labels)]
aggression_diffs$z_time <- z[match(aggression_diffs$topic_labels, topic_labels)]
complexity_diffs$z_time <- z[match(complexity_diffs$topic_labels, topic_labels)]
repetition_diffs$z_time <- z[match(repetition_diffs$topic_labels, topic_labels)]
# NOTE(review): `var` is not read by any code visible in this script
# (plot_gap_vs_time() takes its own `var` argument); it looks like an
# interactive-debugging leftover. Kept in case something outside this file
# relies on it.
var <- "affect"

# Scatter a style's per-topic gender gap against the topic's time trend and
# overlay the fitted regression line.
#
# Arguments:
#   var  Style name; the data frame `<var>_diffs` (with columns `z` and
#        `z_time`) is looked up by name from the enclosing environments.
#   ...  Further arguments forwarded to plot(), e.g. `main`.
#
# Draws on the current graphics device. The regression line is solid red when
# the slope's p-value is below 0.05 and dashed gray otherwise (including when
# the p-value is NA).
#
# Returns the fitted lm of z_time on z.
plot_gap_vs_time <- function(var,...){
  # Look the table up once and reuse it for the plot and the model.
  diffs <- get(paste0(var, "_diffs"))

  plot(diffs$z, diffs$z_time, bty = "n", pch = 19,
       ylab = "Effect of time on topic prevalence",
       xlab = "Gender gap on topic", ...)
  abline(v = 0, lty = 3)
  abline(h = 0, lty = 3)

  model <- lm(z_time ~ z, data = diffs)

  # Row 2, column 4 of the coefficient table is the slope's p-value.
  # isTRUE() turns an NA p-value into "not significant" instead of passing NA
  # as the colour and line type.
  slope_p <- coef(summary(model))[2, 4]
  significant <- isTRUE(slope_p < 0.05)

  abline(model,
         col = if (significant) "red" else "gray",
         lty = if (significant) 1 else 2,
         lwd = 2)

  model
}




# One PDF per topic count k: a 2 x 4 grid with one gap-vs-time panel per style
# measure. Each panel's fitted regression is stored at position i of the
# corresponding accumulator list.
pdf(sprintf("analysis/plots/topic/gap_vs_time_%s.pdf", k), width = 12, height = 6)
par(mfrow = c(2, 4))

# Style name -> panel title, listed in drawing order (panels fill row by row).
panel_titles <- c(affect = "Affect",
                  posemo = "Positive Emotion",
                  negemo = "Negative Emotion",
                  fact = "Fact",
                  anecdote = "Human Narrative",
                  aggression = "Aggression",
                  complexity = "Complexity",
                  repetition = "Repetition")

panel_models <- lapply(names(panel_titles), function(style) {
  plot_gap_vs_time(style, main = panel_titles[[style]])
})
names(panel_models) <- names(panel_titles)

affect_list[[i]] <- panel_models[["affect"]]
posemo_list[[i]] <- panel_models[["posemo"]]
negemo_list[[i]] <- panel_models[["negemo"]]
fact_list[[i]] <- panel_models[["fact"]]
anecdote_list[[i]] <- panel_models[["anecdote"]]
aggression_list[[i]] <- panel_models[["aggression"]]
complexity_list[[i]] <- panel_models[["complexity"]]
repetition_list[[i]] <- panel_models[["repetition"]]
dev.off()

}


# Plot the slope of each stored regression against the number of topics.
#
# Arguments:
#   tmp_list  List of fitted models, one per entry of the global
#             `topic_numbers` and in the same order. Row 2 of each model's
#             coefficient table is used as the slope.
#   ...       Further arguments forwarded to plot(), e.g. `main`.
#
# Draws points with +/- 1.96 SE bars on the current graphics device. Estimates
# with |est / se| >= 1.96 are solid black, the others are faded.
#
# Returns (invisibly) the plotted data.frame with columns
# est, se, hi, lo, sig and k.
plot_coef_list <- function(tmp_list = affect_list,...){
  # `$<-` would silently recycle a `topic_numbers` whose length divides the
  # number of models, so check the alignment explicitly.
  stopifnot(length(tmp_list) == length(topic_numbers))

  slope_rows <- lapply(tmp_list, function(fit) {
    coef_table <- coef(summary(fit))
    est <- coef_table[2, 1]
    se <- coef_table[2, 2]
    data.frame(est = est,
               se = se,
               hi = est + 1.96 * se,
               lo = est - 1.96 * se,
               sig = abs(est / se) >= 1.96)
  })
  x <- do.call("rbind", slope_rows)
  x$k <- topic_numbers

  # Same colour rule for the points and their interval bars.
  point_col <- ifelse(x$sig, "black", alpha("black", .2))

  plot(x = x$k, y = x$est, pch = 19,
       xlab = "Number of topics",
       col = point_col,
       ylab = "Estimate", bty = "n", cex = 1.2, ylim = c(-2, 2), ...)
  segments(x0 = x$k, y0 = x$lo, y1 = x$hi, col = point_col, lwd = 2)
  abline(h = 0, lty = 2)

  invisible(x)
}

# Summary figure across all topic counts: a 2 x 4 grid with one panel per
# style measure, each showing the stored regression slopes against K.
pdf("analysis/plots/topic/gap_vs_time_all_topics.pdf", width = 12, height = 6)
par(mfrow = c(2, 4))

# Panel title -> list of fitted models, in drawing order.
summary_panels <- list("Affect" = affect_list,
                       "Positive Emotion" = posemo_list,
                       "Negative Emotion" = negemo_list,
                       "Fact" = fact_list,
                       "Human Narrative" = anecdote_list,
                       "Aggression" = aggression_list,
                       "Complexity" = complexity_list,
                       "Repetition" = repetition_list)

for (panel_title in names(summary_panels)) {
  plot_coef_list(summary_panels[[panel_title]], main = panel_title)
}
dev.off()